import datetime
import json
import logging
from collections import defaultdict
from collections.abc import Iterator
from json import JSONDecodeError
from typing import Optional

from pydantic import BaseModel, ConfigDict

from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import (
    CustomConfiguration,
    ModelSettings,
    SystemConfiguration,
    SystemConfigurationStatus,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
from core.model_runtime.entities.provider_entities import (
    ConfigurateMethod,
    CredentialFormSchema,
    FormType,
    ProviderEntity,
)
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from extensions.ext_database import db
from models.provider import (
    LoadBalancingModelConfig,
    Provider,
    ProviderModel,
    ProviderModelSetting,
    ProviderType,
    TenantPreferredModelProvider,
)

logger = logging.getLogger(__name__)

# Process-wide cache keyed by provider name. Each value is a snapshot of the
# configurate methods a provider exposed the first time it was seen in this module,
# taken before ProviderConfiguration.__init__ may append PREDEFINED_MODEL to the
# provider's own list.
original_provider_configurate_methods = {}


class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.
    """
    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        """
        Build the configuration and adjust the provider's configurate methods.

        The configurate methods the provider exposes are snapshotted once per provider
        name in the module-level cache. If that snapshot is exactly
        ``[CUSTOMIZABLE_MODEL]`` and any system quota configuration restricts models,
        ``PREDEFINED_MODEL`` is appended to the provider's configurate methods
        (unless it is already present).

        :param data: field values forwarded to the pydantic model constructor
        """
        super().__init__(**data)

        provider_name = self.provider.provider
        current_methods = self.provider.configurate_methods

        # remember the methods as they were the first time this provider is seen
        if provider_name not in original_provider_configurate_methods:
            original_provider_configurate_methods[provider_name] = list(current_methods)

        if original_provider_configurate_methods[provider_name] != [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            return

        has_restrict_models = any(
            len(quota.restrict_models) > 0
            for quota in self.system_configuration.quota_configurations
        )
        if has_restrict_models and ConfigurateMethod.PREDEFINED_MODEL not in current_methods:
            current_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return:
        """
        if self.model_settings:
            # check if model is disabled by admin
            for model_setting in self.model_settings:
                if (model_setting.model_type == model_type
                        and model_setting.model == model):
                    if not model_setting.enabled:
                        raise ValueError(f'Model {model} is disabled.')

        if self.using_provider_type == ProviderType.SYSTEM:
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models

            copy_credentials = self.system_configuration.credentials.copy()
            if restrict_models:
                for restrict_model in restrict_models:
                    if (restrict_model.model_type == model_type
                            and restrict_model.model == model
                            and restrict_model.base_model_name):
                        copy_credentials['base_model_name'] = restrict_model.base_model_name

            return copy_credentials
        else:
            credentials = None
            if self.custom_configuration.models:
                for model_configuration in self.custom_configuration.models:
                    if model_configuration.model_type == model_type and model_configuration.model == model:
                        credentials = model_configuration.credentials
                        break

            if self.custom_configuration.provider:
                credentials = self.custom_configuration.provider.credentials

            return credentials

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled or has no quota
            configuration for the current quota type, ACTIVE when the current quota
            configuration is valid, otherwise QUOTA_EXCEEDED
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type),
            None
        )

        # No quota configuration matches the current quota type: there is nothing to
        # evaluate, so report the system configuration as unusable instead of failing
        # with an AttributeError on None.
        if current_quota_configuration is None:
            return SystemConfigurationStatus.UNSUPPORTED

        if current_quota_configuration.is_valid:
            return SystemConfigurationStatus.ACTIVE
        return SystemConfigurationStatus.QUOTA_EXCEEDED

    def is_custom_configuration_available(self) -> bool:
        """
        Tell whether any custom configuration exists.

        :return: True when provider-level custom configuration is set or at least one
            custom model configuration is present
        """
        custom = self.custom_configuration
        if custom.provider is not None:
            return True
        return len(custom.models) > 0

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Return the provider-level custom credentials.

        :param obfuscated: when True, secret fields (as declared by the provider
            credential schema) are passed through ``obfuscated_credentials``
        :return: the credentials, or None when no provider-level custom
            configuration exists
        """
        provider_configuration = self.custom_configuration.provider
        if provider_configuration is None:
            return None

        credentials = provider_configuration.credentials
        if obfuscated:
            # a provider without a credential schema has no secret fields to mask
            schema = self.provider.provider_credential_schema
            form_schemas = schema.credential_form_schemas if schema else []
            return self.obfuscated_credentials(
                credentials=credentials,
                credential_form_schemas=form_schemas
            )

        return credentials

    def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
        """
        Validate custom credentials.

        Secret fields submitted as the ``[__HIDDEN__]`` placeholder are replaced with the
        decrypted value from the stored record (when one exists and holds that key)
        before validation. After validation every secret field is encrypted in the
        returned credentials. The dict passed in by the caller is not modified.
        Errors raised by the provider factory validation propagate to the caller.

        :param credentials: provider credentials
        :return: tuple of the existing custom provider record (None when there is none)
            and the credentials returned by validation with secret fields encrypted
        """
        # work on a shallow copy so the caller's dict is never mutated
        credentials = dict(credentials)

        # get provider
        provider_record = db.session.query(Provider) \
            .filter(
            Provider.tenant_id == self.tenant_id,
            Provider.provider_name == self.provider.provider,
            Provider.provider_type == ProviderType.CUSTOM.value
        ).first()

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema else []
        )

        if provider_record:
            original_credentials = {}
            encrypted_config = provider_record.encrypted_config
            if encrypted_config:
                if not encrypted_config.startswith("{"):
                    # fix origin data: a stored config that is not a JSON object is
                    # treated as a bare `openai_api_key` value
                    original_credentials = {
                        "openai_api_key": encrypted_config
                    }
                else:
                    try:
                        original_credentials = json.loads(encrypted_config)
                    except JSONDecodeError:
                        # unreadable stored config: behave as if nothing was stored
                        original_credentials = {}

            # restore hidden secrets from the stored record
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == '[__HIDDEN__]' and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider,
            credentials=credentials
        )

        # encrypt secret fields before they are handed back for persistence
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

        return provider_record, credentials

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Create or update the tenant's custom provider credentials.

        The credentials are validated first, then stored as JSON on the existing custom
        provider record or on a newly created one. Afterwards the provider credentials
        cache entry for that record is deleted and the preferred provider type is
        switched to CUSTOM.

        :param credentials: provider credentials
        :return:
        """
        # validate custom provider config
        provider_record, validated_credentials = self.custom_credentials_validate(credentials)
        serialized_config = json.dumps(validated_credentials)

        # save provider
        if not provider_record:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=serialized_config,
                is_valid=True
            )
            db.session.add(provider_record)
        else:
            provider_record.encrypted_config = serialized_config
            provider_record.is_valid = True
            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        # drop the cached credentials so the next read picks up the new values
        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_record.id,
            cache_type=ProviderCredentialsCacheType.PROVIDER
        ).delete()

        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Remove the tenant's custom provider credentials.

        Does nothing when no custom provider record exists. Otherwise the preferred
        provider type is switched to SYSTEM, the record is deleted and committed, and
        the provider credentials cache entry for the record is deleted.

        :return:
        """
        # get provider
        query = db.session.query(Provider).filter(
            Provider.tenant_id == self.tenant_id,
            Provider.provider_name == self.provider.provider,
            Provider.provider_type == ProviderType.CUSTOM.value
        )
        provider_record = query.first()
        if not provider_record:
            return

        self.switch_preferred_provider_type(ProviderType.SYSTEM)

        # delete provider
        db.session.delete(provider_record)
        db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_record.id,
            cache_type=ProviderCredentialsCacheType.PROVIDER
        ).delete()

    def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
            -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return:
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema else []
                )

        return None

    def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \
            -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        Secret fields submitted as the ``[__HIDDEN__]`` placeholder are replaced with the
        decrypted value from the stored record (when one exists and holds that key)
        before validation. After validation every secret field is encrypted in the
        returned credentials. The dict passed in by the caller is not modified.
        Errors raised by the provider factory validation propagate to the caller.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return: tuple of the existing provider model record (None when there is none)
            and the credentials returned by validation with secret fields encrypted
        """
        # work on a shallow copy so the caller's dict is never mutated
        credentials = dict(credentials)

        # get provider model
        provider_model_record = db.session.query(ProviderModel) \
            .filter(
            ProviderModel.tenant_id == self.tenant_id,
            ProviderModel.provider_name == self.provider.provider,
            ProviderModel.model_name == model,
            ProviderModel.model_type == model_type.to_origin_model_type()
        ).first()

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema else []
        )

        if provider_model_record:
            try:
                original_credentials = json.loads(
                    provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
            except JSONDecodeError:
                # unreadable stored config: behave as if nothing was stored
                original_credentials = {}

            # restore hidden secrets from the stored record
            for key, value in credentials.items():
                if key in provider_credential_secret_variables:
                    # if send [__HIDDEN__] in secret input, it will be same as original value
                    if value == '[__HIDDEN__]' and key in original_credentials:
                        credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider,
            model_type=model_type,
            model=model,
            credentials=credentials
        )

        # encrypt secret fields before they are handed back for persistence
        for key, value in credentials.items():
            if key in provider_credential_secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)

        return provider_model_record, credentials

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Create or update the custom credentials of a single model.

        The credentials are validated first, then stored as JSON on the existing
        provider model record or on a newly created one. Afterwards the model
        credentials cache entry for that record is deleted. The preferred provider
        type is left unchanged.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, validated_credentials = self.custom_model_credentials_validate(
            model_type, model, credentials
        )
        serialized_config = json.dumps(validated_credentials)

        # save provider model
        if not provider_model_record:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=serialized_config,
                is_valid=True
            )
            db.session.add(provider_model_record)
        else:
            provider_model_record.encrypted_config = serialized_config
            provider_model_record.is_valid = True
            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        # drop the cached credentials so the next read picks up the new values
        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL
        ).delete()

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Remove the custom credentials of a single model.

        Does nothing when no matching provider model record exists. Otherwise the
        record is deleted and committed, and the model credentials cache entry for the
        record is deleted.

        :param model_type: model type
        :param model: model name
        :return:
        """
        # get provider model
        query = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == self.tenant_id,
            ProviderModel.provider_name == self.provider.provider,
            ProviderModel.model_name == model,
            ProviderModel.model_type == model_type.to_origin_model_type()
        )
        provider_model_record = query.first()
        if not provider_model_record:
            return

        # delete provider model
        db.session.delete(provider_model_record)
        db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL
        ).delete()

    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable a model for the tenant.

        Sets ``enabled`` to True on the existing model setting, or creates an enabled
        setting when none exists, and commits the change.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        origin_model_type = model_type.to_origin_model_type()
        setting = db.session.query(ProviderModelSetting).filter(
            ProviderModelSetting.tenant_id == self.tenant_id,
            ProviderModelSetting.provider_name == self.provider.provider,
            ProviderModelSetting.model_type == origin_model_type,
            ProviderModelSetting.model_name == model
        ).first()

        if not setting:
            setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=origin_model_type,
                model_name=model,
                enabled=True
            )
            db.session.add(setting)
        else:
            setting.enabled = True
            setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return setting

    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable a model for the tenant.

        Sets ``enabled`` to False on the existing model setting, or creates a disabled
        setting when none exists, and commits the change.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        origin_model_type = model_type.to_origin_model_type()
        setting = db.session.query(ProviderModelSetting).filter(
            ProviderModelSetting.tenant_id == self.tenant_id,
            ProviderModelSetting.provider_name == self.provider.provider,
            ProviderModelSetting.model_type == origin_model_type,
            ProviderModelSetting.model_name == model
        ).first()

        if not setting:
            setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=origin_model_type,
                model_name=model,
                enabled=False
            )
            db.session.add(setting)
        else:
            setting.enabled = False
            setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return setting

    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Look up the tenant's setting record for one model.

        :param model_type: model type
        :param model: model name
        :return: the first matching model setting, or None when there is none
        """
        query = db.session.query(ProviderModelSetting).filter(
            ProviderModelSetting.tenant_id == self.tenant_id,
            ProviderModelSetting.provider_name == self.provider.provider,
            ProviderModelSetting.model_type == model_type.to_origin_model_type(),
            ProviderModelSetting.model_name == model
        )
        return query.first()

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Turn on load balancing for a model.

        Requires more than one load balancing configuration for the model. Sets
        ``load_balancing_enabled`` to True on the existing model setting, or creates a
        setting with it enabled when none exists, and commits the change.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        :raises ValueError: if the model has one or fewer load balancing configurations
        """
        origin_model_type = model_type.to_origin_model_type()

        config_count = db.session.query(LoadBalancingModelConfig).filter(
            LoadBalancingModelConfig.tenant_id == self.tenant_id,
            LoadBalancingModelConfig.provider_name == self.provider.provider,
            LoadBalancingModelConfig.model_type == origin_model_type,
            LoadBalancingModelConfig.model_name == model
        ).count()
        if config_count <= 1:
            raise ValueError('Model load balancing configuration must be more than 1.')

        setting = db.session.query(ProviderModelSetting).filter(
            ProviderModelSetting.tenant_id == self.tenant_id,
            ProviderModelSetting.provider_name == self.provider.provider,
            ProviderModelSetting.model_type == origin_model_type,
            ProviderModelSetting.model_name == model
        ).first()

        if not setting:
            setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=origin_model_type,
                model_name=model,
                load_balancing_enabled=True
            )
            db.session.add(setting)
        else:
            setting.load_balancing_enabled = True
            setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return setting

    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Turn off load balancing for a model.

        Sets ``load_balancing_enabled`` to False on the existing model setting, or
        creates a setting with it disabled when none exists, and commits the change.

        :param model_type: model type
        :param model: model name
        :return: the updated or newly created model setting
        """
        origin_model_type = model_type.to_origin_model_type()
        setting = db.session.query(ProviderModelSetting).filter(
            ProviderModelSetting.tenant_id == self.tenant_id,
            ProviderModelSetting.provider_name == self.provider.provider,
            ProviderModelSetting.model_type == origin_model_type,
            ProviderModelSetting.model_name == model
        ).first()

        if not setting:
            setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=origin_model_type,
                model_name=model,
                load_balancing_enabled=False
            )
            db.session.add(setting)
        else:
            setting.load_balancing_enabled = False
            setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        db.session.commit()

        return setting

    def get_provider_instance(self) -> ModelProvider:
        """
        Resolve the provider instance for this configuration's provider.

        :return: the instance the model provider factory returns for the provider name
        """
        provider_name = self.provider.provider
        return model_provider_factory.get_provider_instance(provider_name)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return:
        """
        # Get provider instance
        provider_instance = self.get_provider_instance()

        # Get model instance of LLM
        return provider_instance.get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.
        :param provider_type:
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = db.session.query(TenantPreferredModelProvider) \
            .filter(
            TenantPreferredModelProvider.tenant_id == self.tenant_id,
            TenantPreferredModelProvider.provider_name == self.provider.provider
        ).first()

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas:
        :return:
        """
        secret_input_form_variables = []
        for credential_form_schema in credential_form_schemas:
            if credential_form_schema.type == FormType.SECRET_INPUT:
                secret_input_form_variables.append(credential_form_schema.variable)

        return secret_input_form_variables

    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials
        :param credential_form_schemas: credential form schemas
        :return:
        """
        # Get provider credential secret variables
        credential_secret_variables = self.extract_secret_variables(
            credential_form_schemas
        )

        # Obfuscate provider credentials
        copy_credentials = credentials.copy()
        for key, value in copy_credentials.items():
            if key in credential_secret_variables:
                copy_credentials[key] = encrypter.obfuscated_token(value)

        return copy_credentials

    def get_provider_model(self, model_type: ModelType,
                           model: str,
                           only_active: bool = False) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.
        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return:
        """
        provider_models = self.get_provider_models(model_type, only_active)

        for provider_model in provider_models:
            if provider_model.model == model:
                return provider_model

        return None

    def get_provider_models(self, model_type: Optional[ModelType] = None,
                            only_active: bool = False) -> list[ModelWithProviderEntity]:
        """
        List the models of this provider.

        Models come from the system provider listing when the system provider type is
        in use, otherwise from the custom provider listing.

        :param model_type: restrict to this model type; when omitted, every model type
            supported by the provider schema is listed
        :param only_active: keep only models whose status is ACTIVE
        :return: models sorted by the value of their model type
        """
        provider_instance = self.get_provider_instance()

        if model_type:
            model_types = [model_type]
        else:
            model_types = provider_instance.get_provider_schema().supported_model_types

        # Group model settings by model type and model
        model_setting_map = defaultdict(dict)
        for setting in self.model_settings:
            model_setting_map[setting.model_type][setting.model] = setting

        if self.using_provider_type == ProviderType.SYSTEM:
            list_models = self._get_system_provider_models
        else:
            list_models = self._get_custom_provider_models

        provider_models = list_models(
            model_types=model_types,
            provider_instance=provider_instance,
            model_setting_map=model_setting_map
        )

        if only_active:
            provider_models = [
                provider_model for provider_model in provider_models
                if provider_model.status == ModelStatus.ACTIVE
            ]

        # resort provider_models
        return sorted(provider_models, key=lambda provider_model: provider_model.model_type.value)

    def _get_system_provider_models(self,
                                    model_types: list[ModelType],
                                    provider_instance: ModelProvider,
                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
            -> list[ModelWithProviderEntity]:
        """
        List the models available through the system provider.

        The provider's own models of the requested types are listed first. When the
        quota configuration of the current quota type restricts models, and the
        provider originally exposed only customizable models, a model entry is built
        from each restricted model via the system credentials. Finally every LLM that
        is not in the restricted list is marked NO_PERMISSION, and all other listed
        models are marked QUOTA_EXCEEDED when the quota configuration is not valid.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model settings grouped by model type, then model name
        :return: provider models with their resolved status
        """
        provider_name = self.provider.provider

        def resolve_status(setting_type, setting_model) -> ModelStatus:
            # a model is ACTIVE unless a matching setting explicitly disables it
            if setting_type in model_setting_map and setting_model in model_setting_map[setting_type]:
                if model_setting_map[setting_type][setting_model].enabled is False:
                    return ModelStatus.DISABLED
            return ModelStatus.ACTIVE

        def build_entity(schema, fetch_from) -> ModelWithProviderEntity:
            # wrap a model schema together with this provider and its resolved status
            return ModelWithProviderEntity(
                model=schema.model,
                label=schema.label,
                model_type=schema.model_type,
                features=schema.features,
                fetch_from=fetch_from,
                model_properties=schema.model_properties,
                deprecated=schema.deprecated,
                provider=SimpleModelProviderEntity(self.provider),
                status=resolve_status(schema.model_type, schema.model)
            )

        provider_models = []
        for model_type in model_types:
            provider_models.extend(
                build_entity(listed_model, listed_model.fetch_from)
                for listed_model in provider_instance.models(model_type)
            )

        # make sure the configurate methods of this provider are cached
        if provider_name not in original_provider_configurate_methods:
            original_provider_configurate_methods[provider_name] = []
            original_provider_configurate_methods[provider_name].extend(
                provider_instance.get_provider_schema().configurate_methods
            )

        # only customizable model
        customizable_only = (
            original_provider_configurate_methods[provider_name] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]
        )

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if customizable_only:
                for restrict_model in restrict_models:
                    # per-model copy so base_model_name never leaks into the shared credentials
                    copy_credentials = self.system_configuration.credentials.copy()
                    if restrict_model.base_model_name:
                        copy_credentials['base_model_name'] = restrict_model.base_model_name

                    try:
                        model_instance = provider_instance.get_model_instance(restrict_model.model_type)
                        custom_model_schema = model_instance.get_customizable_model_schema_from_credentials(
                            restrict_model.model,
                            copy_credentials
                        )
                    except Exception as ex:
                        logger.warning(f'get custom model schema failed, {ex}')
                        continue

                    # skip models without a schema or of a type that was not requested
                    if not custom_model_schema or custom_model_schema.model_type not in model_types:
                        continue

                    provider_models.append(build_entity(custom_model_schema, FetchFrom.PREDEFINED_MODEL))

            # if llm name not in restricted llm list, remove it
            restrict_model_names = [restricted.model for restricted in restrict_models]
            for provider_model in provider_models:
                if provider_model.model_type == ModelType.LLM and provider_model.model not in restrict_model_names:
                    provider_model.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    provider_model.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_custom_provider_models(self,
                                    model_types: list[ModelType],
                                    provider_instance: ModelProvider,
                                    model_setting_map: dict[ModelType, dict[str, ModelSettings]]) \
            -> list[ModelWithProviderEntity]:
        """
        Collect the models available under the custom configuration.

        The result contains, in order:
          1. for every requested model type the provider supports, the models
             returned by `provider_instance.models(...)`;
          2. the models listed in `self.custom_configuration.models` whose
             schema can be resolved from their stored credentials.

        :param model_types: model types to collect models for
        :param provider_instance: provider instance used to list models and to
          resolve customizable model schemas
        :param model_setting_map: model settings, keyed by model type and then by model name
        :return: list of models with provider info, status and load balancing flag
        """

        def to_entity(schema, default_status):
            # Shared by both passes below: apply the per-model settings (if any)
            # on top of `default_status`, then wrap the schema in an entity.
            status = default_status
            load_balancing_enabled = False
            if schema.model_type in model_setting_map and schema.model in model_setting_map[schema.model_type]:
                setting = model_setting_map[schema.model_type][schema.model]
                # an explicit `enabled is False` overrides whatever the default status was
                if setting.enabled is False:
                    status = ModelStatus.DISABLED

                # load balancing only counts as enabled with more than one config
                load_balancing_enabled = len(setting.load_balancing_configs) > 1

            return ModelWithProviderEntity(
                model=schema.model,
                label=schema.label,
                model_type=schema.model_type,
                features=schema.features,
                fetch_from=schema.fetch_from,
                model_properties=schema.model_properties,
                deprecated=schema.deprecated,
                provider=SimpleModelProviderEntity(self.provider),
                status=status,
                load_balancing_enabled=load_balancing_enabled
            )

        custom_provider = self.custom_configuration.provider
        provider_credentials = custom_provider.credentials if custom_provider else None

        # without provider-level credentials the listed models are reported as not configured
        listed_model_status = ModelStatus.ACTIVE if provider_credentials else ModelStatus.NO_CONFIGURE

        collected = []

        # pass 1: models listed by the provider instance for each requested, supported type
        for requested_type in model_types:
            if requested_type not in self.provider.supported_model_types:
                continue

            for listed_model in provider_instance.models(requested_type):
                collected.append(to_entity(listed_model, listed_model_status))

        # pass 2: custom models
        # NOTE(review): only the configured model_type is filtered against `model_types`;
        # the resolved schema's own model_type is not re-checked - confirm they always match.
        for custom_model in self.custom_configuration.models:
            if custom_model.model_type not in model_types:
                continue

            try:
                model_instance = provider_instance.get_model_instance(custom_model.model_type)
                custom_schema = model_instance.get_customizable_model_schema_from_credentials(
                    custom_model.model,
                    custom_model.credentials
                )
            except Exception as ex:
                # best effort: a model whose schema cannot be resolved is skipped, not fatal
                logger.warning(f'get custom model schema failed, {ex}')
                continue

            if not custom_schema:
                continue

            collected.append(to_entity(custom_schema, ModelStatus.ACTIVE))

        return collected


class ProviderConfigurations(BaseModel):
    """
    Dict-like container of provider configurations for one tenant.

    Item access, iteration, `values()` and `get()` are forwarded to the
    underlying `configurations` dict.
    """
    # tenant the configurations belong to
    tenant_id: str
    # configurations keyed by a string (presumably the provider name - confirm against callers)
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        # only tenant_id is accepted here; `configurations` starts from its default
        super().__init__(tenant_id=tenant_id)

    def get_models(self,
                   provider: Optional[str] = None,
                   model_type: Optional[ModelType] = None,
                   only_active: bool = False) \
            -> list[ModelWithProviderEntity]:
        """
        Get available models across the stored provider configurations.

        Mode resolution, by preferred provider type:

        Preferred provider type `system`:
          Use the current **system mode** if the provider supports it;
          if no system mode is available (no quota), fall back to the **custom credential mode**.
          If no model is configured in custom mode, it is treated as no_configure.
        system > custom > no_configure

        Preferred provider type `custom`:
          If custom credentials are configured, it is treated as custom mode.
          Otherwise use the current **system mode** if supported;
          if no system mode is available (no quota), it is treated as no_configure.
        custom > system > no_configure

        Models returned per real mode:

        Real mode `system`: system credentials are used to get models,
          paid quotas > provider free quotas > system free quotas
          pre-defined models are included (GPT-4 excluded, status marked as `no_permission`).
        Real mode `custom`: workspace custom credentials are used to get models,
          pre-defined models and custom models (manually appended) are included.
        Real mode `no_configure`: only pre-defined models from `model runtime` are returned
          (model status marked as `no_configure` if preferred provider type is `custom`
          otherwise `quota_exceeded`).
        A model whose status is `active` is available.

        :param provider: provider name; when empty or None, all providers are included
        :param model_type: model type
        :param only_active: only active models
        :return: models of every matching provider configuration, in storage order
        """
        return [
            model
            for configuration in self.values()
            if not provider or configuration.provider.provider == provider
            for model in configuration.get_provider_models(model_type, only_active)
        ]

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Return the stored provider configurations as a new list.

        :return: list of provider configurations
        """
        return [*self.values()]

    def __getitem__(self, key):
        """Return the configuration stored under `key` (KeyError if absent, as with a dict)."""
        return self.configurations[key]

    def __setitem__(self, key, value):
        """Store `value` under `key`, replacing any existing entry."""
        self.configurations[key] = value

    def __iter__(self):
        """Iterate over the keys of the stored configurations, like a plain dict."""
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        """
        Return the values of the underlying dict.

        NOTE(review): the annotation says Iterator, but the object returned is
        the dict's values view, not a one-shot iterator.
        """
        return self.configurations.values()

    def get(self, key, default=None):
        """Return the configuration stored under `key`, or `default` if there is none."""
        return self.configurations.get(key, default)


class ProviderModelBundle(BaseModel):
    """
    Bundle of a provider configuration with the provider instance and the
    model type instance that go with it.
    """
    # configuration of the provider
    configuration: ProviderConfiguration
    # provider instance (a ModelProvider)
    provider_instance: ModelProvider
    # model type instance (an AIModel)
    model_type_instance: AIModel

    # pydantic configs:
    # - arbitrary_types_allowed: accept field types pydantic has no built-in validation for
    #   (presumably needed for ModelProvider / AIModel - confirm)
    # - protected_namespaces=(): clear pydantic's protected name prefixes so the
    #   `model_`-prefixed field above does not trigger a namespace warning
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )
